{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"A very simple MNIST classifier.\n",
    "See extensive documentation at\n",
    "https://www.tensorflow.org/get_started/mnist/beginners\n",
    "\"\"\"\n",
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import argparse\n",
    "import sys\n",
    "\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "FLAGS = None\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们在这里调用系统提供的Mnist数据函数为我们读入数据，如果没有下载的话则进行下载。\n",
    "\n",
    "<font color=#ff0000>**这里将data_dir改为适合你的运行环境的目录**</font>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /tmp/tensorflow/mnist/input_data\\train-images-idx3-ubyte.gz\n",
      "Extracting /tmp/tensorflow/mnist/input_data\\train-labels-idx1-ubyte.gz\n",
      "Extracting /tmp/tensorflow/mnist/input_data\\t10k-images-idx3-ubyte.gz\n",
      "Extracting /tmp/tensorflow/mnist/input_data\\t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Import data\n",
    "data_dir = '/tmp/tensorflow/mnist/input_data'\n",
    "mnist = input_data.read_data_sets(data_dir, one_hot=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "一个非常非常简陋的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define loss and optimizer\n",
    "x = tf.placeholder(tf.float32, shape=[None, 784])\n",
    "y_ = tf.placeholder(tf.float32, [None, 10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义一个函数，用于初始化所有的权值 W\n",
    "def weight_variable(shape):\n",
    "    initial = tf.truncated_normal(shape, stddev=0.1)\n",
    "    return tf.Variable(initial)\n",
    "# 定义一个函数，用于初始化所有的偏置项 b\n",
    "def bias_variable(shape):\n",
    "    initial = tf.constant(0.1, shape=shape)\n",
    "    return tf.Variable(initial)\n",
    "# 定义一个函数，用于构建卷积层\n",
    "def conv2d(x, W):\n",
    "    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')\n",
    "# 定义一个函数，用于构建池化层\n",
    "def max_pool(x):\n",
    "    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n",
    "\n",
    "x_image = tf.reshape(x, [-1, 28, 28, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构建网络\n",
    "W_conv1 = weight_variable([5, 5, 1, 32])\n",
    "b_conv1 = bias_variable([32])\n",
    "\n",
    "h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)  # 第一个卷积层\n",
    "h_pool1 = max_pool(h_conv1)  # 第一个池化层\n",
    "\n",
    "W_conv2 = weight_variable([5, 5, 32, 64])\n",
    "b_conv2 = bias_variable([64])\n",
    "h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)  # 第二个卷积层\n",
    "h_pool2 = max_pool(h_conv2)  # 第二个池化层\n",
    "\n",
    "W_fc1 = weight_variable([7 * 7 * 64, 1024])\n",
    "b_fc1 = bias_variable([1024])\n",
    "h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])  # reshape成向量\n",
    "h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)  # 全连接层\n",
    "\n",
    "keep_prob = tf.placeholder(\"float\")\n",
    "h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)  # dropout层\n",
    "\n",
    "W_fc2 = weight_variable([1024, 10])#第二层全连接\n",
    "b_fc2 = bias_variable([10])\n",
    "y = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)  # softmax层"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来我们计算交叉熵，注意这里不要使用注释中的手动计算方式，而是使用系统函数。\n",
    "另一个注意点就是，softmax_cross_entropy_with_logits的logits参数是**未经激活的wx+b**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "cross_entropy = -tf.reduce_sum(y_ * tf.log(y))  # 交叉熵"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "生成一个训练step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(cross_entropy)  # 梯度下降法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在这里我们仍然调用系统提供的读取数据，为我们取得一个batch。。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0 training accuracy 0.06\n",
      "step 2000 training accuracy 1.0\n",
      "step 4000 training accuracy 0.94\n",
      "step 6000 training accuracy 1.0\n",
      "step 8000 training accuracy 1.0\n",
      "step 10000 training accuracy 1.0\n",
      "step 12000 training accuracy 1.0\n",
      "step 14000 training accuracy 0.98\n",
      "step 16000 training accuracy 1.0\n",
      "step 18000 training accuracy 1.0\n",
      "test accuracy 0.993\n"
     ]
    }
   ],
   "source": [
    "# Train\n",
    "#for _ in range(10000):\n",
    "  #batch_xs, batch_ys = mnist.train.next_batch(50)\n",
    "  #sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})\n",
    "init = tf.global_variables_initializer()\n",
    "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, \"float\"))  # 精确度计算    \n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    for i in range(20000):\n",
    "        batch = mnist.train.next_batch(50)\n",
    "        if i % 2000 == 0:  # 训练2000次，验证一次\n",
    "            train_acc = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})\n",
    "            print('step', i, 'training accuracy', train_acc)\n",
    "        train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})\n",
    "    test_acc = accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})\n",
    "        \n",
    "print(\"test accuracy\", test_acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "验证我们模型在测试数据上的准确率"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "毫无疑问，这个模型是一个非常简陋，性能也不理想的模型。目前只能达到92%左右的准确率。\n",
    "接下来，希望大家利用现有的知识，将这个模型优化至98%以上的准确率。\n",
    "Hint：\n",
    "- 卷积\n",
    "- 池化\n",
    "- 激活函数\n",
    "- 正则化\n",
    "- 初始化\n",
    "- 摸索一下各个超参数\n",
    "  - 卷积kernel size\n",
    "  - 卷积kernel 数量\n",
    "  - 学习率\n",
    "  - 正则化惩罚因子\n",
    "  - 最好每隔几个step就对loss、accuracy等等进行一次输出，这样才能有根据地进行调整"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
